classdef CalibrateKrusellWithRobotsEachEconomy < handle
    % Used in production function calibration.
    %
    % Jointly calibrates the inputs of two economies, :code:`econ_0`
    % (period 0) and :code:`econ_1` (period 1). Calibration variables
    % whose names end in ``_0``/``_1`` are split between the two
    % periods; all other variables are shared by both periods. The
    % calibration minimises a sum of squared moment residuals subject
    % to non-linear equality and inequality constraints, using
    % :code:`knitromatlab`.
    
    properties
        econ_0 % KrusellWithRobotsEconomy for period 0
        econ_1 % KrusellWithRobotsEconomy for period 1
        varnames % variables over which to optimize
        targets % struct of calibration targets
        fix_equipment_structures % if true, both periods use K_E_0 and K_S_0
        version % optional version label (char)
    end
    
    properties (SetAccess = protected)
        num_nodes_f % number of Gauss-Legendre nodes for integration over :math:`f()`
        num_nodes_g % number of Gauss-Legendre nodes for integration over :math:`g()`
        num_nodes_bins % number of nodes for integration over bins of :math:`f()`
        num_nodes_phi % number of phi nodes passed on to both economies
        
        nodes_f % pre-computed nodes for integration over :math:`f_i(w)`
        weights_f % pre-computed weights for integration over :math:`f_i(w)`
    end
    
    methods
        function self = CalibrateKrusellWithRobotsEachEconomy(econ_0,econ_1, targets, varnames,varargin)
            % Construct the calibration object and propagate the number
            % of integration nodes to both economies.
            %
            % Args:
            %     econ_0 (object): economy for period 0
            %     econ_1 (object): economy for period 1
            %     targets (struct): calibration targets
            %     varnames (cell): names of the variables over which to
            %         optimize
            %
            % Keyword Args:
            %     num_nodes_f (int): default 10
            %     num_nodes_g (int): default 15
            %     num_nodes_bins (int): default 10
            %     num_nodes_phi (int): default 50
            %     fix_equipment_structures (logical): default false
            %     version (char): default []
            
            % positive integer scalar
            valid_num_nodes = @(x) isnumeric(x) && isscalar(x) && ...
                x==floor(x) && (x>0);
            
            p = inputParser;
            p.addRequired('econ_0',@isobject)
            p.addRequired('econ_1',@isobject)
            p.addRequired('targets',@isstruct)
            p.addRequired('varnames',@iscell)
            p.addParameter('num_nodes_f',10, valid_num_nodes)
            p.addParameter('num_nodes_g',15, valid_num_nodes)
            p.addParameter('num_nodes_bins',10, valid_num_nodes)
            p.addParameter('num_nodes_phi',50, valid_num_nodes)
            p.addParameter('fix_equipment_structures',false, @islogical)
            p.addParameter('version',[], @ischar)
            
            p.parse(econ_0, econ_1, targets, varnames,varargin{:});
            
            self.econ_0 = p.Results.econ_0;
            self.econ_1 = p.Results.econ_1;
            self.targets = p.Results.targets;
            self.varnames = p.Results.varnames;
            self.num_nodes_f = p.Results.num_nodes_f;
            self.num_nodes_g = p.Results.num_nodes_g;
            self.num_nodes_bins = p.Results.num_nodes_bins;
            self.num_nodes_phi = p.Results.num_nodes_phi;
            self.fix_equipment_structures = p.Results.fix_equipment_structures;
            self.version = p.Results.version;
            
            [self.nodes_f, self.weights_f] = ...
                tools.Integration.lgwt(self.num_nodes_f,-1,1);
            
            % ensure that num_nodes are consistent throughout
            self.econ_0.set_num_nodes(self.num_nodes_f, ...
                self.num_nodes_g, self.num_nodes_bins, self.num_nodes_phi);
            self.econ_1.set_num_nodes(self.num_nodes_f, ...
                self.num_nodes_g, self.num_nodes_bins, self.num_nodes_phi);
        end
        
        function set_num_nodes(self, num_nodes_f, num_nodes_g, num_nodes_bins, num_nodes_phi)
            % Set the number-of-nodes properties of this
            % :code:`CalibrateKrusellWithRobotsEachEconomy` object and of
            % both economies (:code:`econ_0`, :code:`econ_1`), and
            % re-compute the integration nodes and weights over f.
            %
            % Args:
            %     num_nodes_f (int): number of nodes to use for
            %         Gauss-Legendre integration over :math:`f()`
            %     num_nodes_g (int): number of nodes to use for
            %         Gauss-Legendre integration over :math:`g()`
            %     num_nodes_bins (int): number of nodes to use for
            %         Gauss-Legendre integration over bins of
            %         :math:`f()`
            %     num_nodes_phi (int, optional): number of phi nodes.
            %         If omitted, :code:`num_nodes_phi` is left unchanged
            %         and the economies are called with three arguments
            %         only, as before.
            
            % positive integer scalar
            valid_num_nodes = @(x) isnumeric(x) && isscalar(x) && ...
                x==floor(x) && (x>0);
            
            has_phi = nargin >= 5;
            
            p = inputParser;
            p.addRequired('num_nodes_f', valid_num_nodes);
            p.addRequired('num_nodes_g', valid_num_nodes);
            p.addRequired('num_nodes_bins', valid_num_nodes);
            p.addOptional('num_nodes_phi', self.num_nodes_phi, valid_num_nodes);
            
            if has_phi
                p.parse(num_nodes_f, num_nodes_g, num_nodes_bins, num_nodes_phi);
            else
                p.parse(num_nodes_f, num_nodes_g, num_nodes_bins);
            end
            
            self.num_nodes_f = p.Results.num_nodes_f;
            self.num_nodes_g = p.Results.num_nodes_g;
            self.num_nodes_bins = p.Results.num_nodes_bins;
            
            % re-compute nodes and weights
            [self.nodes_f, self.weights_f] = tools.Integration.lgwt(self.num_nodes_f,-1,1);
            
            % propagate the node counts to both economies
            if has_phi
                self.num_nodes_phi = p.Results.num_nodes_phi;
                self.econ_0.set_num_nodes(self.num_nodes_f, self.num_nodes_g, ...
                    self.num_nodes_bins, self.num_nodes_phi);
                self.econ_1.set_num_nodes(self.num_nodes_f, self.num_nodes_g, ...
                    self.num_nodes_bins, self.num_nodes_phi);
            else
                self.econ_0.set_num_nodes(self.num_nodes_f, self.num_nodes_g, ...
                    self.num_nodes_bins);
                self.econ_1.set_num_nodes(self.num_nodes_f, self.num_nodes_g, ...
                    self.num_nodes_bins);
            end
        end
        
        function [sol,fval,exitflag,output] = calibrate(self, var0, lb, ub)
            % calibrate production function
            %
            % Minimises :code:`obj_ssr` subject to :code:`nonlcon_ssr`
            % and the bounds using :code:`knitromatlab`. Knitro options
            % are read from 'production_function_calibration.opt'.
            %
            % Args:
            %     var0 (struct): struct with initial values for
            %         variables over which to optimize
            %     lb (struct): struct with lower-bound values for
            %         variables over which to optimize
            %     ub (struct): struct with upper-bound values for
            %         variables over which to optimize
            %
            % Returns:
            %     sol (struct): solution, translated back to a struct
            %     fval (double): objective value at the solution
            %     exitflag, output: as returned by :code:`knitromatlab`
            
            x0 = tools.translators.s2x(self,var0);
            
            % no linear (in)equality constraints
            Amat = [];
            b = [];
            Aeq = [];
            beq = [];
            
            lb = tools.translators.s2x(self,lb);
            ub = tools.translators.s2x(self,ub);
            
            obj = @(x) self.obj_ssr(tools.translators.x2s(self,x));
            nonlcon = @(x) self.nonlcon_ssr(tools.translators.x2s(self,x));
            
            % diagnostic output: starting values and constraint values
            % at the starting point (displayed, not otherwise used)
            fprintf('inp0\n')
            var0
            
            fprintf('constraint at x0\n')
            [c,ceq] = nonlcon(x0)
            
            knitrooptions = optimset('Display','iter');
            options_file = 'production_function_calibration.opt';
            [x,fval,exitflag,output] = knitromatlab(obj, x0, Amat, b, Aeq, beq, ...
                lb, ub, nonlcon, [], knitrooptions, options_file);
            
            sol = tools.translators.x2s(self,x);
        end
        
        function out = obj_ssr(self, inp)
            % Sum of squared residuals of the moments targeted in the
            % objective.
            %
            % Uses entries 7 (labor_income_share_0), 15
            % (Y_B_rel_change), 12 (alpha), 13 (rho) and 14 (sigma) of
            % :code:`moment_equations`.
            residuals_tmp = self.moment_equations(inp);
            residuals = residuals_tmp([7,15,12,13,14]);
            % row vector times its conjugate transpose: sum of |r|^2
            out = residuals * residuals';
        end
        
        function [c, ceq] = nonlcon_ssr(self,inp)
            % Non-linear constraints for :code:`calibrate`.
            %
            % ceq: equal marginal products of the robot types in both
            %     periods, followed by entries 1:6 (factor prices Y_M,
            %     Y_R, Y_C in both periods), 11 (K_B_rel_change), 23 and
            %     24 (Y_B_M_0 equal to Y_E_0 and Y_S_0) of
            %     :code:`moment_equations`.
            % c: 0.6 - labor_income_share_0 <= 0 (entry 26).
            [~, ceq_eq] = self.nonlcon_equilibrium_capital_equal(inp);
            mom_eq_tmp = self.moment_equations(inp);
            mom_eq = mom_eq_tmp([1:6,11,23,24]);
            ceq = [ceq_eq, mom_eq];
            % inequality constraint on labor_income_share, higher
            % share not attainable, much too low labor share
            % without this constraint
            c = 0.6 - mom_eq_tmp(26);
        end
        
        function out = moment_equations(self, inp)
            % compute moment equations, i.e. vector of deviations between moments
            % and targets
            %
            % Returns a row vector. Callers select entries by position,
            % so the order must not change (entries 23/24 deliberately
            % repeat 17/18). Entry 26 is the level of
            % labor_income_share_0, not a log deviation.
            
            targ = self.targets;
            mom = self.compute_moments(inp);
            out = [...
                log(targ.Y_M_0) - log(mom.Y_M_0),... % 1
                log(targ.Y_R_0) - log(mom.Y_R_0),... % 2
                log(targ.Y_C_0) - log(mom.Y_C_0),... % 3
                log(targ.Y_M_1) - log(mom.Y_M_1),... % 4
                log(targ.Y_R_1) - log(mom.Y_R_1),... % 5
                log(targ.Y_C_1) - log(mom.Y_C_1),... % 6
                log(targ.labor_income_share_0) - log(mom.labor_income_share_0),... % 7
                log(targ.M_share_in_labor_income_0) - log(mom.M_share_in_labor_income_0),... % 8
                log(targ.R_share_in_labor_income_0) - log(mom.R_share_in_labor_income_0),... % 9
                log(targ.labor_share_in_labor_plus_robot_income_0) - log(mom.labor_share_in_labor_plus_robot_income_0),... % 10
                log(targ.K_B_rel_change) - log(mom.K_B_rel_change), ... % 11
                log(targ.alpha) - log(inp.alpha),... % 12
                log(targ.rho) - log(inp.rho),... % 13
                log(targ.sigma) - log(inp.sigma),... % 14
                log(targ.Y_B_rel_change) - log(mom.Y_B_rel_change),... % 15
                log(targ.Y_B_0) - log(mom.Y_B_M_0),... % 16
                log(mom.Y_B_M_0) - log(mom.Y_E_0),... % 17
                log(mom.Y_B_M_0) - log(mom.Y_S_0),... % 18
                log(targ.Y_E_0) - log(mom.Y_E_0),... % 19
                log(targ.Y_S_0) - log(mom.Y_S_0),... % 20
                log(targ.Y_E_1) - log(mom.Y_E_1),... % 21
                log(targ.Y_S_1) - log(mom.Y_S_1),... % 22
                log(mom.Y_B_M_0) - log(mom.Y_E_0),... % 23
                log(mom.Y_B_M_0) - log(mom.Y_S_0),... % 24
                log(targ.robot_share_capital) - log(mom.robot_share_capital),... % 25
                mom.labor_income_share_0,... % 26 (level)
                log(mom.Y_B_M_0) - log(0.05)]; % 27
        end
        
        function out = compute_moments(self, inp)
            % Compute model moments for both periods.
            %
            % Args:
            %     inp (struct): calibration variables. Period-specific
            %         fields carry a ``_0``/``_1`` suffix. tau_B, tau_E
            %         and tau_S are read from inp directly and therefore
            %         must be given without a period suffix.
            %
            % Returns:
            %     out (struct): factor prices for each period, labor and
            %         robot income statistics for period 0, relative
            %         changes between periods, implied capital prices
            %         q_B/q_E/q_S, tax revenues and the robot share of
            %         the period-0 capital stock.
            
            [inp_0, inp_1] = self.split_inp_two_periods(inp);
            
            % the period struct is passed both as inp and as varargin
            % of the economy methods
            [out.Y_M_0, out.Y_R_0, out.Y_C_0, out.Y_B_M_0, out.Y_B_R_0, ...
                out.Y_B_C_0, out.Y_E_0, out.Y_S_0] = self.econ_0.factor_prices(inp_0,inp_0);
            
            [out.Y_M_1, out.Y_R_1, out.Y_C_1, out.Y_B_M_1, out.Y_B_R_1, ...
                out.Y_B_C_1, out.Y_E_1, out.Y_S_1] = self.econ_1.factor_prices(inp_1,inp_1);
            
            out.K_E_1 = inp_1.K_E;
            out.K_S_1 = inp_1.K_S;
            
            out.labor_income_share_0 = ...
                self.econ_0.labor_income_share(inp_0,inp_0);
            out.labor_income_0 = self.econ_0.labor_income(inp_0, inp_0);
            out.M_share_in_labor_income_0 = inp_0.L_M * out.Y_M_0/out.labor_income_0;
            out.R_share_in_labor_income_0 = inp_0.L_R * out.Y_R_0/out.labor_income_0;
            out.K_B_0 = inp_0.K_B_M + inp_0.K_B_R + inp_0.K_B_C;
            out.K_B_1 = inp_1.K_B_M + inp_1.K_B_R + inp_1.K_B_C;
            out.robot_income_0 = inp.K_B_M_0 * out.Y_B_M_0 + ...
                inp.K_B_R_0 * out.Y_B_R_0 + inp.K_B_C_0 * out.Y_B_C_0;
            out.labor_share_in_labor_plus_robot_income_0 = ...
                out.labor_income_0/(out.labor_income_0 + out.robot_income_0);
            out.K_B_rel_change = (out.K_B_1 - out.K_B_0)/out.K_B_0;
            out.Y_B_rel_change = (out.Y_B_M_1 - out.Y_B_M_0)/out.Y_B_M_0;
            out.price_elasticity_robots = out.K_B_rel_change/out.Y_B_rel_change;
            
            % using that marginal products have to be equal for all
            % robots, Y_B = Y_B_M
            
            tau_B = inp.tau_B;
            tau_E = inp.tau_E;
            tau_S = inp.tau_S;
            
            % compute prices using that Y_j = (1+tau_j)*q_j
            out.q_B = out.Y_B_M_0/(1+tau_B);
            out.q_E = out.Y_E_0/(1+tau_E);
            out.q_S = out.Y_S_0/(1+tau_S);
            
            out.tax_rev_robots = tau_B * out.K_B_0 * out.q_B;
            out.tax_rev_structures = tau_S * inp_0.K_S * out.q_S;
            out.tax_rev_equipment = tau_E * inp_0.K_E * out.q_E;
            
            out.robot_share_capital = out.q_B * out.K_B_0 ...
                /(out.q_B * out.K_B_0 + out.q_E * inp.K_E_0 + out.q_S * inp.K_S_0);
        end
        
        function [c, ceq] = nonlcon_equilibrium_capital_equal(self, inp)
            % non-linear constraints, where :code:`c` are inequality constraints,
            % and :code:`ceq` are equality constraints
            %
            % version which does not solve for labor market
            % clearing. Solving for labor market clearing is not
            % required if factor prices for labor are matched
            % exactly. In that case, aggregate labor supply at the
            % solution is determined. Price levels for capital do not
            % need to be matched. But returns to robots need to be
            % equalized. Assuming that fix_equipment_structures is
            % true. Price levels for other capital are actually
            % determined in this calibration step, based on equal
            % returns to all types of capital.
            %
            % Returns:
            %     c: empty (no inequality constraints)
            %     ceq (1x4): [Y_B_M_0 - Y_B_R_0, Y_B_M_0 - Y_B_C_0,
            %         Y_B_M_1 - Y_B_R_1, Y_B_M_1 - Y_B_C_1]
            
            [inp_0, inp_1] = self.split_inp_two_periods(inp);
            
            % request all eight outputs (nargout unchanged), keep only
            % the robot marginal products
            [~, ~, ~, Y_B_M_0, Y_B_R_0, Y_B_C_0, ~, ~] = ...
                self.econ_0.factor_prices(inp_0,inp_0);
            [~, ~, ~, Y_B_M_1, Y_B_R_1, Y_B_C_1, ~, ~] = ...
                self.econ_1.factor_prices(inp_1,inp_1);
            
            % foc's for capital in period 0 and period 1
            eq_econ_0 = [Y_B_M_0 - Y_B_R_0, Y_B_M_0 - Y_B_C_0];
            eq_econ_1 = [Y_B_M_1 - Y_B_R_1, Y_B_M_1 - Y_B_C_1];
            
            c = [];
            ceq = [eq_econ_0, eq_econ_1];
        end
        
        
        %% helper functions
        
        function [inp_0, inp_1] = split_inp_two_periods(self, inp)
            % split inp into inp for each period. Fields which appear
            % without period indexing are copied to both, inp_0 and inp_1
            %
            % For every field X with both X_0 and X_1 in inp, inp_0.X =
            % inp.X_0 and inp_1.X = inp.X_1, and X_0/X_1 are removed
            % from both outputs. If fix_equipment_structures is true,
            % both periods use K_E_0 and K_S_0.
            %
            % Args:
            %     inp (struct): variables which correspond to both periods
            %
            % Returns:
            %      [inp_0, inp_1] (structs): structs of renamed variables
            %          for period 0 and 1
            
            fields = self.get_fields_both_periods(inp);
            
            % start from full copies, then rename the period fields
            inp_0 = inp;
            inp_1 = inp;
            
            for i=1:length(fields)
                field = fields{i};
                [inp_0,inp_1] = self.assign_variable_periods(inp_0,inp_1,inp,field);
            end
            
            if self.fix_equipment_structures
                inp_0.K_E = inp.K_E_0;
                inp_0.K_S = inp.K_S_0;
                inp_1.K_E = inp.K_E_0;
                inp_1.K_S = inp.K_S_0;
            end
        end
        
        function [strct_period_0,strct_period_1] = assign_variable_periods(self,strct_period_0,strct_period_1,strct,field)
            % assigns field with suffix to struct_period, and removes
            % the suffixed fields from both period structs. Example:
            % assigns inp_0.L_M = inp.L_M_0 and inp_1.L_M = inp.L_M_1,
            % and removes L_M_0 and L_M_1 from inp_0 and inp_1.
            
            strct_period_0.(field) = strct.([field,'_0']);
            strct_period_1.(field) = strct.([field,'_1']);
            
            % remove fields assigned from strct from strct_period_0
            % and strct_period_1
            strct_period_0 = rmfield(strct_period_0,{[field,'_0'],[field,'_1']});
            strct_period_1 = rmfield(strct_period_1,{[field,'_0'],[field,'_1']});
        end
        
        function fields = get_fields_both_periods(self, inp)
            % returns the base names X for which :code:`inp` has both
            % a field X_0 and a field X_1.
            %
            % The period indicator is matched and stripped only as a
            % suffix, consistent with the suffix lookup in
            % :code:`assign_variable_periods`.
            names = fieldnames(inp);
            period_0_names = names(endsWith(names,'_0'));
            period_1_names = names(endsWith(names,'_1'));
            
            % get names without trailing period indicator
            period_0_names_stripped = regexprep(period_0_names,'_0$','');
            period_1_names_stripped = regexprep(period_1_names,'_1$','');
            
            % get intersection of both sets
            fields = intersect(period_0_names_stripped, ...
                period_1_names_stripped);
        end
        
    end
    
end

